
Memoization

As the name suggests, memoization is a technique of keeping a memo of computed results so that they do not have to be computed again. This is particularly useful in dynamic programming, where problems have overlapping subproblems.

Overlapping Subproblems in Fibonacci Sequence

In the Fibonacci sequence, if we expand the larger of the two terms of the $n$-th Fibonacci number, we see that $f(n - 2)$ is used twice.

$$
\begin{aligned}
f(n) &= f(n - 1) + f(n - 2) \\
&= \underline{f(n - 2)} + f(n - 3) + \underline{f(n - 2)}
\end{aligned}
$$

Because each call branches into two further calls, the recursion tree of the Fibonacci sequence problem is a binary tree of depth about $n$. The number of recursive calls is therefore $O(2^n)$, which is exponential. As $n$ grows, the number of overlapping subproblems, and with it the amount of repeated computation, also grows exponentially.
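To make the exponential count precise, let $C(n)$ be the number of calls the naive recursion makes to compute $f(n)$, with $C(0) = C(1) = 1$. The name $C(n)$ is our own notation, not the lesson's. Adding 1 to both sides turns the call-count recurrence into the Fibonacci recurrence:

$$
\begin{aligned}
C(n) &= 1 + C(n - 1) + C(n - 2), \quad n \ge 2 \\
C(n) + 1 &= \bigl(C(n - 1) + 1\bigr) + \bigl(C(n - 2) + 1\bigr) \\
C(n) &= 2f(n + 1) - 1
\end{aligned}
$$

The last line follows because $C(n) + 1$ obeys the Fibonacci recurrence with starting values $C(0) + 1 = C(1) + 1 = 2 = 2f(1) = 2f(2)$. Since $f(n)$ grows like $\varphi^n$ with $\varphi = (1 + \sqrt{5})/2 \approx 1.618$, the naive recursion makes $\Theta(\varphi^n)$ calls. This is consistent with the $O(2^n)$ bound.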

Figure: Time complexity grows as $O(2^n)$. Yellow nodes represent repeated computation on overlapping subproblems.
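To see the blow-up concretely, here is a minimal Python sketch of the naive recursion with a call counter added. `count_calls` is a hypothetical helper name and is not part of the original lesson.

```python
def count_calls(n):
    """Return (f(n), number of calls) for the naive, unmemoized recursion."""
    calls = 0

    def fib(k):
        nonlocal calls
        calls += 1                       # count every invocation, including repeats
        if k < 2:                        # base cases: f(0) = 0, f(1) = 1
            return k
        return fib(k - 1) + fib(k - 2)   # no memo, so subproblems are recomputed

    result = fib(n)
    return result, calls


for n in (10, 20):
    print(n, count_calls(n))
# 10 (55, 177)
# 20 (6765, 21891)
```

Doubling $n$ from 10 to 20 multiplies the number of calls by more than 100. Yet only 21 distinct subproblems, $f(0)$ through $f(20)$, exist.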

Exponential growth in time complexity quickly becomes impractical, so we need a way to optimize it. Memoization is one way to do so.

Applying Memoization

As mentioned before, memoization simply means keeping a memo of results. In other words, we store each result in memory the first time we compute it, and reuse it whenever we need it again. This is generally done with an array or a hash table.
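As a sketch of the hash-table variant, the snippet below keeps the memo in a Python `dict` keyed by $n$. The names `memo`, `fib_hash`, and `fib_cached` are illustrative only. Python's standard library also provides the `functools.lru_cache` decorator, which maintains a cache of results keyed by the call arguments, so it applies the same idea automatically.

```python
from functools import lru_cache

# Hash-table memo: maps n -> f(n), with the base cases stored up front.
memo = {0: 0, 1: 1}


def fib_hash(n):
    """Memoized Fibonacci using a dict as the memo."""
    if n in memo:                                # result already stored: reuse it
        return memo[n]
    memo[n] = fib_hash(n - 1) + fib_hash(n - 2)  # compute once, then store
    return memo[n]


# The standard-library decorator keeps an equivalent cache for us.
@lru_cache(maxsize=None)
def fib_cached(n):
    return n if n < 2 else fib_cached(n - 1) + fib_cached(n - 2)


print(fib_hash(30), fib_cached(30))  # 832040 832040
```

A dict needs no pre-sizing and also works for sparse or non-integer keys. An array gives direct indexing when the keys are dense integers such as $0, 1, \dots, n$.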

To implement this for the Fibonacci sequence problem, we will use an array $\text{dp}$ to store the results.

Before doing any computation, we first check whether the result is already stored in $\text{dp}$. If it is, we simply return it. Otherwise, we compute the result and store it in $\text{dp}$.

n-th Fibonacci Number with Memoization
```
// initialize the memoization array
let dp = [null, null, ..., null] of size n + 1
let dp[0] = 0
let dp[1] = 1

// fibonacci function
function fib(n: int) -> int {
    // check if we have the memoized result
    if dp[n] is not null { return dp[n] }

    // compute the result
    dp[n] = fib(n - 1) + fib(n - 2)
    return dp[n]
}
```

If we visualize the recursion tree of the above algorithm, we can see that it has only about $2n$ nodes. Each of $f(2), \dots, f(n)$ is computed exactly once, and every other call returns a stored result immediately. This reduces the time complexity to $O(n)$.
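Here is a direct Python translation of the single-call pseudocode, with a call counter added to check the node count. `fib_single` is a hypothetical name. The `if n >= 1` guard is our addition, so that $n = 0$ does not write past the end of a one-element array.

```python
def fib_single(n):
    """Memoized Fibonacci for a single query; dp is created fresh for this n.

    Returns (f(n), number of recursive calls made).
    """
    dp = [None] * (n + 1)   # dp[i] will hold f(i) once computed
    dp[0] = 0
    if n >= 1:
        dp[1] = 1           # guard: with n == 0 the array has only one slot
    calls = 0

    def fib(k):
        nonlocal calls
        calls += 1
        if dp[k] is not None:            # memoized result available
            return dp[k]
        dp[k] = fib(k - 1) + fib(k - 2)  # computed at most once per k
        return dp[k]

    result = fib(n)
    return result, calls


print(fib_single(10))  # (55, 19)
print(fib_single(30))  # (832040, 59)
```

For $n \ge 1$ the count is exactly $2n - 1$. The chain $f(n), f(n - 1), \dots, f(1)$ contributes $n$ calls, and each of the $n - 1$ computed values makes one extra call that hits the memo.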

Figure: Time complexity grows as $O(n)$. The repeated computations are eliminated by directly reading stored results from $\text{dp}$.

Multiple Calls Optimization

In the previous implementation, we assumed only one call to the function with a known $n$.

Suppose we call the function multiple times, e.g. $f(7)$, then $f(13)$, then $f(10)$. We would compute the results of $f(0)$ to $f(7)$ three times in total, because we reinitialize $\text{dp}$ to all nulls every time we run the algorithm.

To optimize this further, we make $\text{dp}$ an array shared by all calls to the function and grow it as needed. Later calls then reuse the results computed by earlier ones, which reduces the average time per call when multiple calls are made.
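A small experiment makes the saving concrete. The sketch below assumes a hypothetical factory `make_fib` that closes over its own $\text{dp}$ list. Reusing one instance models the shared array, while creating a new instance per query models reinitializing $\text{dp}$ every time.

```python
def make_fib():
    """Return a memoized fib whose dp list persists across calls, plus a call counter."""
    dp = [0, 1]
    stats = {"calls": 0}

    def fib(n):
        stats["calls"] += 1
        if len(dp) <= n:                        # grow the memo as needed
            dp.extend([None] * (n + 1 - len(dp)))
        if dp[n] is not None:                   # reuse results from any earlier call
            return dp[n]
        dp[n] = fib(n - 1) + fib(n - 2)
        return dp[n]

    return fib, stats


queries = [7, 13, 10]

# Shared memo: one fib function (and one dp list) serves all queries.
fib, shared_stats = make_fib()
shared_results = [fib(q) for q in queries]

# Fresh memo per query: mimics reinitializing dp before every run.
fresh_calls = 0
for q in queries:
    fresh_fib, stats = make_fib()
    fresh_fib(q)
    fresh_calls += stats["calls"]

print(shared_results, shared_stats["calls"], fresh_calls)  # [13, 233, 55] 27 57
```

With the shared array, $f(13)$ only has to compute $f(8)$ through $f(13)$, and $f(10)$ becomes a single lookup.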

n-th Fibonacci Number with Memoization - Additional Optimization
```
// initialize the memoization array
let dp = [0, 1]

// fibonacci function
function fib(n: int) -> int {
    // grow the array if needed
    if dp.length <= n { dp.resize(n + 1, filled with null) }

    // check if we have the memoized result
    if dp[n] is not null { return dp[n] }

    // compute the result
    dp[n] = fib(n - 1) + fib(n - 2)
    return dp[n]
}
```
Checkpoint

In memoization, what do we make extra use of to reduce the number of computations on overlapping subproblems?

Mathematical expressions
Function calls
Memory space
Global variables

Implementation

In Python, the implementation is almost identical to the pseudocode above, so it can be written as follows.

n-th Fibonacci Number with Memoization
```python
# initialize the memoization list
dp = [0, 1]

# fibonacci function
def fib(n):
    # grow the list if needed
    if len(dp) <= n:
        dp.extend([None] * (n + 1 - len(dp)))

    # check if we have the memoized result
    if dp[n] is not None:
        return dp[n]

    # compute the result
    dp[n] = fib(n - 1) + fib(n - 2)
    return dp[n]

# main function
def main():
    # gets input and prints the result
    n = int(input())
    print(fib(n))

if __name__ == "__main__":
    main()
```
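One practical caveat is our own addition rather than part of the original lesson. CPython limits recursion depth (1000 frames by default, as reported by `sys.getrecursionlimit()`), so calling `fib(5000)` on a fresh memo raises `RecursionError`. Because $\text{dp}$ is shared across calls, one workaround is to warm it up in steps so that each recursive descent stays shallow. The helper `fib_large` and its `step` parameter are hypothetical names used only for this sketch.

```python
dp = [0, 1]  # shared memo, as in the implementation above


def fib(n):
    if len(dp) <= n:
        dp.extend([None] * (n + 1 - len(dp)))
    if dp[n] is not None:
        return dp[n]
    dp[n] = fib(n - 1) + fib(n - 2)
    return dp[n]


def fib_large(n, step=200):
    """Hypothetical helper: fill the shared memo in chunks so each
    recursive descent is at most about `step` frames deep."""
    for k in range(0, n, step):
        fib(k)   # each call only recurses down to the last filled entry
    return fib(n)


f5000 = fib_large(5000)
print(len(str(f5000)))  # number of decimal digits of f(5000)
```

Calling `sys.setrecursionlimit` is another option, but raising the limit too far can crash the interpreter. Warming the shared array keeps the depth bounded by `step`.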